balanced_accuracy_score#

Balanced accuracy is the macro-average of recall: it computes recall separately for each class and then averages across classes.

Quick import#

from sklearn.metrics import balanced_accuracy_score

It is especially useful when:

  • the dataset is imbalanced

  • you want to treat classes equally (e.g., you care about minority recall as much as majority recall)

Goals#

  • Build intuition (why accuracy can be misleading)

  • Derive the metric for binary and multiclass classification

  • Implement balanced_accuracy_score from scratch in NumPy

  • Visualize per-class recall and threshold effects (Plotly)

  • Use balanced accuracy to guide a simple optimization loop (from-scratch logistic regression)

Prerequisites#

  • Confusion matrix, recall (TPR), specificity (TNR)

  • Probabilistic classifiers (logistic regression outputs probabilities)

import numpy as np

import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
from plotly.subplots import make_subplots

from scipy.special import expit

from sklearn.linear_model import LogisticRegression
from sklearn.metrics import balanced_accuracy_score as sk_balanced_accuracy_score
from sklearn.model_selection import train_test_split

pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
pio.templates.default = "plotly_white"

rng = np.random.default_rng(42)
np.set_printoptions(precision=4, suppress=True)

1) Why not plain accuracy?#

Accuracy is

\[ \text{Accuracy} = \frac{\#\{i : \hat{y}_i = y_i\}}{n}. \]

If one class dominates, a model can achieve high accuracy by mostly predicting the majority class.

Example: 99% negatives, 1% positives.

  • A classifier that always predicts negative gets 99% accuracy.

  • But it has 0% recall on the positive class.

Balanced accuracy fixes this by computing recall per class and averaging them, so each class contributes equally.

2) Definition (binary classification)#

For \(y \in \{0,1\}\) and predictions \(\hat{y} \in \{0,1\}\):

|            | \(\hat{y}=0\) | \(\hat{y}=1\) |
|------------|---------------|---------------|
| \(y=0\)    | TN            | FP            |
| \(y=1\)    | FN            | TP            |

Two key rates:

  • Recall / sensitivity / TPR

\[ \text{TPR} = \frac{\text{TP}}{\text{TP} + \text{FN}} \]
  • Specificity / TNR

\[ \text{TNR} = \frac{\text{TN}}{\text{TN} + \text{FP}} \]

Balanced accuracy is the mean of these two:

\[ \text{BA} = \frac{1}{2}(\text{TPR} + \text{TNR}). \]

It is also related to the balanced error rate (BER):

\[ \text{BER} = 1 - \text{BA} = \tfrac{1}{2}(\text{FNR} + \text{FPR}). \]
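As a quick sanity check (the confusion-matrix counts are arbitrary, chosen just for illustration), we can verify the BA/BER relationship numerically:

```python
import numpy as np

# Hand-made binary confusion matrix (illustrative values)
tn, fp, fn, tp = 50, 10, 5, 35

tpr = tp / (tp + fn)          # recall / sensitivity
tnr = tn / (tn + fp)          # specificity
ba = 0.5 * (tpr + tnr)        # balanced accuracy

fnr = fn / (tp + fn)          # miss rate = 1 - TPR
fpr = fp / (tn + fp)          # fall-out  = 1 - TNR
ber = 0.5 * (fnr + fpr)       # balanced error rate

print(f"TPR={tpr:.4f}  TNR={tnr:.4f}  BA={ba:.4f}  BER={ber:.4f}")
assert np.isclose(ba, 1.0 - ber)
```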

3) Definition (multiclass + adjusted)#

For \(K\) classes, balanced accuracy is the average recall per class (macro recall):

\[ \text{BA} = \frac{1}{K} \sum_{k=1}^K \text{Recall}_k \qquad\text{where}\qquad \text{Recall}_k = \frac{\text{TP}_k}{\text{TP}_k + \text{FN}_k}. \]

So balanced accuracy is exactly:

\[ \text{BA} = \texttt{recall\_score}(\text{average}="macro"). \]
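A quick check of this identity on an arbitrary small multiclass example:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score, recall_score

# Small 3-class example (labels chosen arbitrarily)
y_true = np.array([0, 0, 1, 1, 1, 2, 2, 2, 2])
y_pred = np.array([0, 1, 1, 1, 2, 2, 2, 0, 2])

ba = balanced_accuracy_score(y_true, y_pred)
macro_recall = recall_score(y_true, y_pred, average="macro")

print("balanced accuracy:", ba)
print("macro recall:     ", macro_recall)
assert np.isclose(ba, macro_recall)
```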

Adjusted balanced accuracy#

scikit-learn also offers a chance-corrected version:

\[ \text{BA}_{\text{adj}} = \frac{\text{BA} - 1/K}{1 - 1/K}. \]
  • A classifier that effectively behaves like random guessing tends toward \(\text{BA}_{\text{adj}} \approx 0\).

  • Perfect classification gives \(\text{BA}_{\text{adj}} = 1\).

  • Worse-than-chance performance gives a negative score.
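A quick illustration of the chance correction (synthetic random labels, independent of the data used later): with uniformly random predictions over \(K=3\) classes, per-class recall is about \(1/K\), so plain BA lands near \(1/3\) while the adjusted score hovers near 0:

```python
import numpy as np
from sklearn.metrics import balanced_accuracy_score

rng = np.random.default_rng(0)

# True labels and uniformly random "guesses" over 3 classes
y_true = rng.integers(0, 3, size=10_000)
y_pred = rng.integers(0, 3, size=10_000)

ba = balanced_accuracy_score(y_true, y_pred)
ba_adj = balanced_accuracy_score(y_true, y_pred, adjusted=True)

print(f"BA  ≈ {ba:.3f}  (chance level is 1/3)")
print(f"adjusted BA ≈ {ba_adj:.3f}  (≈ 0 for random guessing)")
```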

def accuracy_score_np(y_true, y_pred, sample_weight=None) -> float:
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    correct = (y_true == y_pred).astype(float)

    if sample_weight is None:
        return float(correct.mean())

    w = np.asarray(sample_weight, dtype=float)
    return float(np.sum(w * correct) / np.sum(w))


def per_class_recall_np(
    y_true,
    y_pred,
    labels=None,
    sample_weight=None,
    zero_division: float = 0.0,
):
    # Per-class recall:
    #   recall_k = (# predicted as k among true k) / (# true k)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if labels is None:
        labels = np.unique(y_true)
    labels = np.asarray(labels)

    if sample_weight is None:
        sample_weight = np.ones_like(y_true, dtype=float)
    else:
        sample_weight = np.asarray(sample_weight, dtype=float)

    recalls = np.empty(len(labels), dtype=float)

    for i, cls in enumerate(labels):
        mask = y_true == cls
        denom = float(sample_weight[mask].sum())
        if denom == 0.0:
            recalls[i] = zero_division
        else:
            num = float(sample_weight[mask & (y_pred == cls)].sum())
            recalls[i] = num / denom

    return recalls, labels


def balanced_accuracy_score_np(
    y_true,
    y_pred,
    *,
    labels=None,
    sample_weight=None,
    adjusted: bool = False,
    zero_division: float = 0.0,
) -> float:
    recalls, labels_used = per_class_recall_np(
        y_true,
        y_pred,
        labels=labels,
        sample_weight=sample_weight,
        zero_division=zero_division,
    )
    score = float(np.mean(recalls))

    if not adjusted:
        return score

    n_classes = len(labels_used)
    if n_classes <= 1:
        return 1.0

    chance = 1.0 / n_classes
    return float((score - chance) / (1.0 - chance))


def confusion_matrix_np(y_true, y_pred, labels=None, sample_weight=None):
    # Small confusion-matrix helper (mainly for plotting)
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)

    if labels is None:
        labels = np.unique(np.concatenate([y_true, y_pred]))
    labels = np.asarray(labels)

    label_to_index = {label: i for i, label in enumerate(labels)}

    true_idx = np.array([label_to_index.get(v, -1) for v in y_true], dtype=int)
    pred_idx = np.array([label_to_index.get(v, -1) for v in y_pred], dtype=int)

    if sample_weight is None:
        sample_weight = np.ones_like(true_idx, dtype=float)
    else:
        sample_weight = np.asarray(sample_weight, dtype=float)

    cm = np.zeros((len(labels), len(labels)), dtype=float)
    valid = (true_idx >= 0) & (pred_idx >= 0)
    np.add.at(cm, (true_idx[valid], pred_idx[valid]), sample_weight[valid])

    return cm, labels


# quick sanity check vs scikit-learn
_y_true = np.array([0, 0, 0, 1, 1, 1])
_y_pred = np.array([0, 0, 1, 0, 1, 1])
print('ours:', balanced_accuracy_score_np(_y_true, _y_pred))
print('sklearn:', sk_balanced_accuracy_score(_y_true, _y_pred))
ours: 0.6666666666666666
sklearn: 0.6666666666666666

4) Worked example: “always predict the majority class”#

Let’s build an extremely imbalanced dataset and evaluate a trivial classifier.

  • 990 negatives (class 0)

  • 10 positives (class 1)

  • predictions: always class 0

This classifier gets excellent accuracy, but poor minority-class performance.

n_neg, n_pos = 990, 10

y_true = np.array([0] * n_neg + [1] * n_pos)
y_pred = np.zeros_like(y_true)

acc = accuracy_score_np(y_true, y_pred)
bal = balanced_accuracy_score_np(y_true, y_pred)
bal_adj = balanced_accuracy_score_np(y_true, y_pred, adjusted=True)
recalls, labels = per_class_recall_np(y_true, y_pred)

print(f"accuracy:          {acc:.4f}")
print(f"balanced accuracy: {bal:.4f}")
print(f"adjusted BA:       {bal_adj:.4f}")
print("per-class recall:", dict(zip(labels.tolist(), recalls.tolist())))
accuracy:          0.9900
balanced accuracy: 0.5000
adjusted BA:       0.0000
per-class recall: {0: 1.0, 1: 0.0}
cm, cm_labels = confusion_matrix_np(y_true, y_pred)

fig = px.imshow(
    cm,
    text_auto=True,
    color_continuous_scale="Blues",
    x=[f"pred={l}" for l in cm_labels],
    y=[f"true={l}" for l in cm_labels],
)
fig.update_layout(title="Confusion matrix: always predicting class 0")
fig.show()

fig = go.Figure(
    data=[
        go.Bar(
            x=[str(l) for l in labels],
            y=recalls,
            text=[f"{r:.2f}" for r in recalls],
            textposition="auto",
        )
    ]
)
fig.update_layout(
    title="Per-class recall (balanced accuracy is the mean of these)",
    xaxis_title="class",
    yaxis_title="recall",
    yaxis=dict(range=[0, 1]),
)
fig.show()

5) Threshold dependence (probabilities → labels)#

Balanced accuracy is defined on hard predictions (\(\hat{y}\)).

If your model outputs probabilities \(p(x) = P(y=1\mid x)\), you still need a decision threshold \(t\):

\[ \hat{y} = \mathbb{1}[p(x) \ge t]. \]

Changing \(t\) changes the confusion matrix, hence recall per class, hence balanced accuracy.

# A simple probability simulation (overlapping scores + class imbalance)

n_neg, n_pos = 2000, 100

y_true = np.array([0] * n_neg + [1] * n_pos)

# Negatives tend to have lower predicted probabilities, positives higher, but overlapping.
p_neg = rng.beta(2.0, 8.0, size=n_neg)
p_pos = rng.beta(5.0, 5.0, size=n_pos)

proba = np.concatenate([p_neg, p_pos])

# Shuffle together
perm = rng.permutation(len(y_true))
y_true = y_true[perm]
proba = proba[perm]

thresholds = np.linspace(0.0, 1.0, 401)
accs = np.empty_like(thresholds)
bals = np.empty_like(thresholds)

for i, t in enumerate(thresholds):
    y_pred = (proba >= t).astype(int)
    accs[i] = accuracy_score_np(y_true, y_pred)
    bals[i] = balanced_accuracy_score_np(y_true, y_pred)

best_t = float(thresholds[np.argmax(bals)])

fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=accs, name="accuracy", mode="lines"))
fig.add_trace(go.Scatter(x=thresholds, y=bals, name="balanced accuracy", mode="lines"))
fig.add_vline(x=best_t, line_dash="dash", line_color="black")
fig.update_layout(
    title=f"Accuracy vs balanced accuracy as a function of threshold (best BA at t={best_t:.3f})",
    xaxis_title="threshold t",
    yaxis_title="score",
    yaxis=dict(range=[0, 1]),
)
fig.show()

6) Using balanced accuracy to guide an optimization loop (logistic regression)#

Balanced accuracy is not differentiable w.r.t. model parameters because it depends on discrete decisions (argmax / threshold).

In practice, we typically:

  1. Train a probabilistic classifier with a differentiable loss (e.g., log loss)

  2. Use balanced accuracy as a model selection criterion:

    • choose hyperparameters

    • choose early-stopping epoch

    • choose decision threshold

A common surrogate that often improves balanced accuracy is to train with class weights (roughly: make each class contribute equally to the loss).

Below we train logistic regression from scratch in two ways:

  • Unweighted log loss

  • Class-weighted log loss (“balanced” weights)

…and monitor validation balanced accuracy for early stopping.

# Synthetic 2D imbalanced dataset (mild overlap)

n0, n1 = 1200, 80

X0 = rng.normal(loc=(0.0, 0.0), scale=1.0, size=(n0, 2))
X1 = rng.normal(loc=(1.2, 1.2), scale=1.0, size=(n1, 2))

X = np.vstack([X0, X1])
y = np.concatenate([np.zeros(n0, dtype=int), np.ones(n1, dtype=int)])

perm = rng.permutation(len(y))
X, y = X[perm], y[perm]

X_train, X_val, y_train, y_val = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

fig = px.scatter(
    x=X[:, 0],
    y=X[:, 1],
    color=y.astype(str),
    opacity=0.7,
    title="Synthetic imbalanced dataset",
    labels={"x": "x1", "y": "x2", "color": "class"},
)
fig.show()

print('train class counts:', {0: int((y_train==0).sum()), 1: int((y_train==1).sum())})
print('val class counts:  ', {0: int((y_val==0).sum()), 1: int((y_val==1).sum())})
train class counts: {0: 900, 1: 60}
val class counts:   {0: 300, 1: 20}
def standardize_fit(X):
    mean = X.mean(axis=0)
    std = X.std(axis=0) + 1e-12
    return mean, std


def standardize_transform(X, mean, std):
    return (X - mean) / std


def add_intercept(X):
    return np.c_[np.ones((X.shape[0], 1)), X]


def predict_proba_logreg(X, w):
    Xb = add_intercept(X)
    return expit(Xb @ w)


def log_loss_binary(y, p, sample_weight=None, eps: float = 1e-12) -> float:
    y = np.asarray(y)
    p = np.clip(np.asarray(p), eps, 1.0 - eps)

    per_sample = -(y * np.log(p) + (1.0 - y) * np.log(1.0 - p))

    if sample_weight is None:
        return float(per_sample.mean())

    w = np.asarray(sample_weight, dtype=float)
    return float(np.sum(w * per_sample) / np.sum(w))


def fit_logreg_gd(
    X_train,
    y_train,
    X_val,
    y_val,
    *,
    lr: float = 0.2,
    n_epochs: int = 400,
    l2: float = 1e-2,
    sample_weight=None,
):
    # Binary logistic regression with (optional) sample weights + early stopping on val BA
    Xb = add_intercept(X_train)
    n, d = Xb.shape

    if sample_weight is None:
        sample_weight = np.ones(n, dtype=float)
    else:
        sample_weight = np.asarray(sample_weight, dtype=float)

    sw_sum = float(sample_weight.sum())
    w = np.zeros(d, dtype=float)

    history = {
        "train_loss": [],
        "val_acc": [],
        "val_bal_acc": [],
    }

    best = {
        "epoch": -1,
        "val_bal_acc": -np.inf,
        "w": w.copy(),
    }

    for epoch in range(n_epochs):
        # forward + gradient on train
        p_train = expit(Xb @ w)
        grad = (Xb.T @ (sample_weight * (p_train - y_train))) / sw_sum
        grad[1:] += l2 * w[1:]

        w = w - lr * grad

        # metrics
        p_train = expit(Xb @ w)
        train_loss = log_loss_binary(y_train, p_train, sample_weight=sample_weight) + 0.5 * l2 * float(
            np.sum(w[1:] ** 2)
        )

        p_val = predict_proba_logreg(X_val, w)
        y_val_hat = (p_val >= 0.5).astype(int)

        val_acc = accuracy_score_np(y_val, y_val_hat)
        val_bal_acc = balanced_accuracy_score_np(y_val, y_val_hat)

        history["train_loss"].append(train_loss)
        history["val_acc"].append(val_acc)
        history["val_bal_acc"].append(val_bal_acc)

        if val_bal_acc > best["val_bal_acc"]:
            best = {"epoch": epoch, "val_bal_acc": val_bal_acc, "w": w.copy()}

    return best["w"], history, best
# Standardize features (important for GD stability)
mean, std = standardize_fit(X_train)
X_train_s = standardize_transform(X_train, mean, std)
X_val_s = standardize_transform(X_val, mean, std)

# Unweighted training
w_unw, hist_unw, best_unw = fit_logreg_gd(X_train_s, y_train, X_val_s, y_val)

# Balanced class weights: each class gets ~50% of total weight
n_train = len(y_train)
n_pos = int((y_train == 1).sum())
n_neg = int((y_train == 0).sum())

w_pos = n_train / (2.0 * n_pos)
w_neg = n_train / (2.0 * n_neg)
sw_bal = np.where(y_train == 1, w_pos, w_neg)

w_wt, hist_wt, best_wt = fit_logreg_gd(X_train_s, y_train, X_val_s, y_val, sample_weight=sw_bal)

print('best epoch (unweighted):', best_unw['epoch'], 'val BA:', f"{best_unw['val_bal_acc']:.4f}")
print('best epoch (weighted):  ', best_wt['epoch'], 'val BA:', f"{best_wt['val_bal_acc']:.4f}")
best epoch (unweighted): 179 val BA: 0.5250
best epoch (weighted):   53 val BA: 0.8367
epochs = np.arange(1, len(hist_unw["train_loss"]) + 1)

fig = make_subplots(
    rows=1,
    cols=3,
    subplot_titles=("Train log loss", "Validation accuracy", "Validation balanced accuracy"),
)

for name, hist in [("unweighted", hist_unw), ("class-weighted", hist_wt)]:
    fig.add_trace(
        go.Scatter(x=epochs, y=hist["train_loss"], name=f"{name} loss", mode="lines"),
        row=1,
        col=1,
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=hist["val_acc"], name=f"{name} acc", mode="lines"),
        row=1,
        col=2,
    )
    fig.add_trace(
        go.Scatter(x=epochs, y=hist["val_bal_acc"], name=f"{name} BA", mode="lines"),
        row=1,
        col=3,
    )

fig.update_layout(height=350, width=1100, title="Training curves (early stopping uses validation BA)")
fig.update_yaxes(range=[0, 1], row=1, col=2)
fig.update_yaxes(range=[0, 1], row=1, col=3)
fig.show()
def best_threshold_for_balanced_accuracy(y_true, proba, thresholds):
    best = {"t": None, "ba": -np.inf}
    for t in thresholds:
        y_pred = (proba >= t).astype(int)
        ba = balanced_accuracy_score_np(y_true, y_pred)
        if ba > best["ba"]:
            best = {"t": float(t), "ba": float(ba)}
    return best


thresholds = np.linspace(0.0, 1.0, 401)

p_unw = predict_proba_logreg(X_val_s, w_unw)
p_wt = predict_proba_logreg(X_val_s, w_wt)

best_t_unw = best_threshold_for_balanced_accuracy(y_val, p_unw, thresholds)
best_t_wt = best_threshold_for_balanced_accuracy(y_val, p_wt, thresholds)

print('best threshold (unweighted):', best_t_unw)
print('best threshold (weighted):  ', best_t_wt)

# Visualize BA(t)
ba_unw = [balanced_accuracy_score_np(y_val, (p_unw >= t).astype(int)) for t in thresholds]
ba_wt = [balanced_accuracy_score_np(y_val, (p_wt >= t).astype(int)) for t in thresholds]

fig = go.Figure()
fig.add_trace(go.Scatter(x=thresholds, y=ba_unw, name="unweighted", mode="lines"))
fig.add_trace(go.Scatter(x=thresholds, y=ba_wt, name="class-weighted", mode="lines"))
fig.add_vline(x=best_t_unw["t"], line_dash="dash", line_color="#1f77b4")
fig.add_vline(x=best_t_wt["t"], line_dash="dash", line_color="#ff7f0e")
fig.update_layout(
    title="Validation balanced accuracy as a function of the decision threshold",
    xaxis_title="threshold t",
    yaxis_title="balanced accuracy",
    yaxis=dict(range=[0, 1]),
)
fig.show()
best threshold (unweighted): {'t': 0.1525, 'ba': 0.8433333333333334}
best threshold (weighted):   {'t': 0.62, 'ba': 0.8416666666666667}
def summarize_threshold(y_true, proba, t):
    y_pred = (proba >= t).astype(int)
    acc = accuracy_score_np(y_true, y_pred)
    ba = balanced_accuracy_score_np(y_true, y_pred)
    recalls, labels = per_class_recall_np(y_true, y_pred)
    cm, _ = confusion_matrix_np(y_true, y_pred, labels=np.array([0, 1]))
    return {
        "t": float(t),
        "acc": float(acc),
        "ba": float(ba),
        "recalls": dict(zip(labels.tolist(), recalls.tolist())),
        "cm": cm,
    }


summaries = {
    "unweighted @0.5": summarize_threshold(y_val, p_unw, 0.5),
    "unweighted @t*": summarize_threshold(y_val, p_unw, best_t_unw["t"]),
    "weighted @0.5": summarize_threshold(y_val, p_wt, 0.5),
    "weighted @t*": summarize_threshold(y_val, p_wt, best_t_wt["t"]),
}

for k, v in summaries.items():
    print(k, {"t": v["t"], "acc": v["acc"], "ba": v["ba"], "recalls": v["recalls"]})

# Confusion matrices (2x2): rows=methods, cols=threshold choice
fig = make_subplots(
    rows=2,
    cols=2,
    subplot_titles=(
        "Unweighted @0.5",
        "Unweighted @t*",
        "Weighted @0.5",
        "Weighted @t*",
    ),
)

items = [
    (1, 1, summaries["unweighted @0.5"]),
    (1, 2, summaries["unweighted @t*"]),
    (2, 1, summaries["weighted @0.5"]),
    (2, 2, summaries["weighted @t*"]),
]

for r, c, s in items:
    cm = s["cm"]
    fig.add_trace(
        go.Heatmap(
            z=cm,
            x=["pred=0", "pred=1"],
            y=["true=0", "true=1"],
            colorscale="Blues",
            showscale=False,
            text=cm.astype(int),
            texttemplate="%{text}",
        ),
        row=r,
        col=c,
    )

fig.update_layout(height=650, width=900, title="Validation confusion matrices")
fig.show()
unweighted @0.5 {'t': 0.5, 'acc': 0.940625, 'ba': 0.525, 'recalls': {0: 1.0, 1: 0.05}}
unweighted @t* {'t': 0.1525, 'acc': 0.88125, 'ba': 0.8433333333333334, 'recalls': {0: 0.8866666666666667, 1: 0.8}}
weighted @0.5 {'t': 0.5, 'acc': 0.7375, 'ba': 0.8366666666666667, 'recalls': {0: 0.7233333333333334, 1: 0.95}}
weighted @t* {'t': 0.62, 'acc': 0.834375, 'ba': 0.8416666666666667, 'recalls': {0: 0.8333333333333334, 1: 0.85}}
# Decision boundary visualization (in original feature space)

def decision_boundary_figure(X_val, y_val, w, mean, std, threshold: float, title: str):
    x1_min, x1_max = X_val[:, 0].min() - 1.0, X_val[:, 0].max() + 1.0
    x2_min, x2_max = X_val[:, 1].min() - 1.0, X_val[:, 1].max() + 1.0

    xs = np.linspace(x1_min, x1_max, 200)
    ys = np.linspace(x2_min, x2_max, 200)
    xx, yy = np.meshgrid(xs, ys)
    grid = np.c_[xx.ravel(), yy.ravel()]
    grid_s = standardize_transform(grid, mean, std)

    p = predict_proba_logreg(grid_s, w).reshape(xx.shape)

    fig = go.Figure()

    fig.add_trace(
        go.Contour(
            x=xs,
            y=ys,
            z=p,
            contours=dict(start=threshold, end=threshold, size=1, coloring="lines"),
            line=dict(color="black", width=3),
            showscale=False,
            name="decision boundary",
        )
    )

    fig.add_trace(
        go.Scatter(
            x=X_val[:, 0],
            y=X_val[:, 1],
            mode="markers",
            marker=dict(
                size=6,
                color=y_val,
                colorscale=[[0, "#1f77b4"], [1, "#d62728"]],
                opacity=0.7,
                line=dict(width=0),
            ),
            name="validation points",
        )
    )

    fig.update_layout(
        title=title,
        xaxis_title="x1",
        yaxis_title="x2",
        height=450,
        width=500,
        legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1),
    )

    return fig


fig1 = decision_boundary_figure(
    X_val,
    y_val,
    w_unw,
    mean,
    std,
    threshold=best_t_unw["t"],
    title=f"Unweighted logistic regression (threshold t*={best_t_unw['t']:.2f})",
)
fig2 = decision_boundary_figure(
    X_val,
    y_val,
    w_wt,
    mean,
    std,
    threshold=best_t_wt["t"],
    title=f"Class-weighted logistic regression (threshold t*={best_t_wt['t']:.2f})",
)

fig = make_subplots(rows=1, cols=2, subplot_titles=(fig1.layout.title.text, fig2.layout.title.text))
for tr in fig1.data:
    fig.add_trace(tr, row=1, col=1)
for tr in fig2.data:
    fig.add_trace(tr, row=1, col=2)

fig.update_layout(height=450, width=1050, title="Decision boundary tuned for balanced accuracy")
fig.update_xaxes(title_text="x1", row=1, col=1)
fig.update_yaxes(title_text="x2", row=1, col=1)
fig.update_xaxes(title_text="x1", row=1, col=2)
fig.update_yaxes(title_text="x2", row=1, col=2)
fig.show()

7) Practical scikit-learn usage#

Metric#

from sklearn.metrics import balanced_accuracy_score

balanced_accuracy_score(y_true, y_pred)
balanced_accuracy_score(y_true, y_pred, adjusted=True)

Model selection#

  • In GridSearchCV / cross_val_score, use scoring="balanced_accuracy".

  • Many estimators support class_weight="balanced", which often improves balanced accuracy.
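A minimal sketch of both points together (synthetic data via `make_classification`; the `C` grid is illustrative, not a recommendation):

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import GridSearchCV

# Imbalanced synthetic dataset (~90% / 10%)
X, y = make_classification(n_samples=500, weights=[0.9, 0.1], random_state=0)

# Select C by cross-validated balanced accuracy,
# with balanced class weights in the estimator itself.
grid = GridSearchCV(
    LogisticRegression(max_iter=2000, class_weight="balanced"),
    param_grid={"C": [0.01, 0.1, 1.0, 10.0]},
    scoring="balanced_accuracy",
    cv=5,
)
grid.fit(X, y)

print("best params:", grid.best_params_)
print(f"best CV balanced accuracy: {grid.best_score_:.4f}")
```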

# scikit-learn comparison on the same dataset

clf_unw = LogisticRegression(max_iter=2000)
clf_wt = LogisticRegression(max_iter=2000, class_weight="balanced")

clf_unw.fit(X_train, y_train)
clf_wt.fit(X_train, y_train)

pred_unw = clf_unw.predict(X_val)
pred_wt = clf_wt.predict(X_val)

print('sklearn unweighted BA:', sk_balanced_accuracy_score(y_val, pred_unw))
print('sklearn weighted BA:  ', sk_balanced_accuracy_score(y_val, pred_wt))
sklearn unweighted BA: 0.5483333333333333
sklearn weighted BA:   0.8183333333333334

8) Pros, cons, and when to use it#

Pros#

  • Handles class imbalance better than accuracy (each class contributes equally).

  • Easy to interpret: it is the average recall per class.

  • Works naturally for multiclass problems.

Cons / limitations#

  • Ignores precision: you can increase recall (and BA) by predicting a class more often, possibly creating many false positives.

  • Threshold-dependent: with probabilistic outputs, you may need to tune the decision threshold.

  • Not differentiable → typically used for evaluation/model selection, not as a direct training loss.

  • Equal class weighting may not match real costs (some false negatives/positives may matter more than others).

Good use-cases#

  • Imbalanced classification where you want good recall for every class.

  • Settings where the minority class is important and accuracy would be misleading.

If you care about ranking probabilities rather than hard labels, consider threshold-free metrics such as AUROC or Average Precision (PR AUC).
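For completeness, a small sketch of those threshold-free metrics on simulated scores (the Beta parameters below are arbitrary, chosen only to make positives score higher on average):

```python
import numpy as np
from sklearn.metrics import roc_auc_score, average_precision_score

rng = np.random.default_rng(1)

# Imbalanced labels with overlapping simulated probabilities
y_true = np.array([0] * 900 + [1] * 100)
proba = np.where(y_true == 1, rng.beta(5, 3, size=1000), rng.beta(2, 6, size=1000))

# Both metrics consume the raw scores directly: no threshold needed.
print("AUROC:            ", roc_auc_score(y_true, proba))
print("Average precision:", average_precision_score(y_true, proba))
```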

9) Exercises#

  1. Compute balanced accuracy by hand for a binary confusion matrix.

  2. Implement balanced_accuracy_score_np(..., sample_weight=...) tests:

    • give higher weight to a subset of samples

    • verify it matches sklearn.metrics.balanced_accuracy_score(..., sample_weight=...).

  3. On the synthetic dataset above:

    • compare accuracy vs balanced accuracy as you vary the threshold

    • find a threshold that maximizes balanced accuracy and report the per-class recalls.

  4. Extend the notebook to multiclass:

    • generate 3 classes with imbalance

    • compute per-class recalls and balanced accuracy

    • visualize the confusion matrix.

References#

  • scikit-learn docs: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.balanced_accuracy_score.html

  • scikit-learn user guide (model evaluation): https://scikit-learn.org/stable/modules/model_evaluation.html